import os
import glob
import hashlib
import numpy as np
import pandas as pd
import tkinter as tk
from tkinter import filedialog, messagebox

import mne
import matplotlib.pyplot as plt


# -----------------------------
# SETTINGS
# -----------------------------
# Lateralized electrode pair: LEFT_CH is the left-hemisphere site and
# RIGHT_CH the right-hemisphere site used for the contra/ipsi comparison.
LEFT_CH, RIGHT_CH = "P3", "P4"

# Event labels (matched case-insensitively against epochs.event_id keys)
# for the conditions that have a real right ("R") / left ("L") target side.
EVENTS_TST = dict(R="TR", L="TL")
EVENTS_GT = dict(R="GR", L="GL")

# Single event label of the neutral condition (no target side).
EVENT_N_LABEL = "N"

# (start_ms, end_ms) window shaded on the plots as the N2pc peak search range.
PEAK_SEARCH_MS = (180.0, 320.0)

# Added to every participant's derived seed for the neutral random split.
# Change it to draw a different split for everyone; 0 keeps the default split.
NEUTRAL_SPLIT_SEED_OFFSET = 0


# -----------------------------
# HELPERS
# -----------------------------
def pick_folder(title):
    """Show a folder-picker dialog and return the chosen directory.

    A hidden Tk root window is created only to host the dialog and is always
    destroyed afterwards, even if the dialog raises.

    Args:
        title: Text shown in the dialog's title bar.

    Returns:
        Whatever ``filedialog.askdirectory`` returns: the selected path, or a
        falsy value when the user cancels (callers test it with ``if not``).
    """
    root = tk.Tk()
    root.withdraw()  # hide the empty main window; only the dialog should show
    try:
        return filedialog.askdirectory(title=title)
    finally:
        # Release the Tk root even on error so no stray window/interpreter leaks.
        root.destroy()


def normalize_event_key(s):
    """Return a canonical form of an event label: stringified, trimmed, lower-case."""
    as_text = str(s)
    trimmed = as_text.strip()
    return trimmed.lower()


def find_event_key(event_id, desired_label):
    """Return the first key of ``event_id`` that matches ``desired_label``.

    Keys are compared after ``normalize_event_key`` (trimmed, lower-cased), so
    the match ignores case and surrounding whitespace. The original key
    (unnormalized) is returned so it can be used to index the epochs.

    Returns:
        The matching key, or ``None`` if no key matches.
    """
    wanted = normalize_event_key(desired_label)
    candidates = (
        label for label in event_id if normalize_event_key(label) == wanted
    )
    return next(candidates, None)


def evoked_for_event(epochs, event_label):
    """Average all epochs of one event, matched case-insensitively by label.

    Args:
        epochs: MNE Epochs object (uses ``event_id`` and ``[key]`` selection).
        event_label: Event name to look up via ``find_event_key``.

    Returns:
        The evoked response returned by ``epochs[key].average()``.

    Raises:
        ValueError: If no event key matches ``event_label``.
    """
    matched_key = find_event_key(epochs.event_id, event_label)
    if matched_key is not None:
        return epochs[matched_key].average()

    available = list(epochs.event_id.keys())
    raise ValueError(
        f"Event '{event_label}' not found. Available keys: {available}"
    )


def get_channel_value(evoked, ch_name):
    """Return a copy of one channel's time series from an evoked response.

    Args:
        evoked: Object exposing ``ch_names`` (sequence of channel names) and a
            2-D ``data`` array indexed as ``data[channel, time]``.
        ch_name: Exact (case-sensitive) channel name to extract.

    Returns:
        1-D array copied from ``evoked.data``, so callers may modify it freely.

    Raises:
        ValueError: If ``ch_name`` is not one of ``evoked.ch_names``.
    """
    # Look the name up directly instead of via mne.pick_channels. Recent MNE
    # versions raise their own error for a missing channel (and older ones warn
    # about the changing ``ordered`` default), which made this message unreachable.
    try:
        idx = list(evoked.ch_names).index(ch_name)
    except ValueError:
        raise ValueError(
            f"Channel {ch_name} not found. Available: {evoked.ch_names}"
        ) from None
    return evoked.data[idx, :].copy()


def stable_int_seed_from_string(s: str) -> int:
    """Map ``s`` to a reproducible unsigned 32-bit integer.

    The built-in ``hash`` of a ``str`` is salted per process. The MD5 digest
    of the UTF-8 bytes is not, so the value is the same on every run and
    machine. The first four digest bytes are read as a big-endian integer.
    """
    digest = hashlib.md5(s.encode("utf-8")).digest()
    leading_bytes = digest[:4]
    return int.from_bytes(leading_bytes, byteorder="big")


def compute_contra_ipsi_lr(epochs, event_r, event_l):
    """Side-collapsed contralateral / ipsilateral waveforms for one condition.

    The right-target and left-target trials are averaged separately. For a
    right-side target LEFT_CH is contralateral and RIGHT_CH ipsilateral; for a
    left-side target the roles swap. The two sides are then averaged:

      contra = 0.5 * (LEFT_CH @ right-target + RIGHT_CH @ left-target)
      ipsi   = 0.5 * (RIGHT_CH @ right-target + LEFT_CH @ left-target)
      diff   = contra - ipsi

    Args:
        epochs: MNE Epochs containing both events.
        event_r: Label of the right-target event.
        event_l: Label of the left-target event.

    Returns:
        ``(times_ms, contra, ipsi, diff)``. ``times_ms`` is taken from the
        right-target average; the waveforms are multiplied by 1e6 (volts to
        µV, assuming MNE's volt convention for EEG data).
    """
    right_avg = evoked_for_event(epochs, event_r)
    left_avg = evoked_for_event(epochs, event_l)

    # Extract in a fixed order: right-target contra/ipsi, then left-target.
    contra_from_right = get_channel_value(right_avg, LEFT_CH)
    ipsi_from_right = get_channel_value(right_avg, RIGHT_CH)
    contra_from_left = get_channel_value(left_avg, RIGHT_CH)
    ipsi_from_left = get_channel_value(left_avg, LEFT_CH)

    contra = 0.5 * (contra_from_right + contra_from_left)
    ipsi = 0.5 * (ipsi_from_right + ipsi_from_left)

    to_uv = 1e6
    return (
        right_avg.times * 1000.0,
        contra * to_uv,
        ipsi * to_uv,
        (contra - ipsi) * to_uv,
    )


def compute_contra_ipsi_neutral_random_split(epochs, neutral_label, subj_for_seed):
    """Pseudo-lateralized contra/ipsi waveforms for the side-less neutral condition.

    The neutral trials are randomly split into two halves:

      half A ("pseudo-right"): contra = LEFT_CH,  ipsi = RIGHT_CH
      half B ("pseudo-left"):  contra = RIGHT_CH, ipsi = LEFT_CH

    The halves are then collapsed exactly like the lateralized conditions:

      contra = 0.5 * (mean LEFT_CH in A + mean RIGHT_CH in B)
      ipsi   = 0.5 * (mean RIGHT_CH in A + mean LEFT_CH in B)
      diff   = contra - ipsi

    The split is reproducible per participant: the RNG seed comes from
    ``stable_int_seed_from_string(subj_for_seed)`` plus
    ``NEUTRAL_SPLIT_SEED_OFFSET``, wrapped to 32 bits.

    Returns:
        ``(times_ms, contra, ipsi, diff, seed, n_a, n_b)``. Waveforms are
        multiplied by 1e6. ``n_b`` gets the extra trial when the count is odd.

    Raises:
        ValueError: If the neutral event is missing or has fewer than 2 trials.
    """
    key = find_event_key(epochs.event_id, neutral_label)
    if key is None:
        available = list(epochs.event_id.keys())
        raise ValueError(
            f"Neutral event '{neutral_label}' not found. Available keys: {available}"
        )

    neutral = epochs[key]
    n_trials = len(neutral)
    if n_trials < 2:
        raise ValueError(f"Not enough neutral trials to split (n={n_trials}).")

    base_seed = stable_int_seed_from_string(subj_for_seed)
    seed = (base_seed + int(NEUTRAL_SPLIT_SEED_OFFSET)) % (2**32)
    # Legacy RandomState.permutation(n) == arange(n) shuffled in place.
    order = np.random.RandomState(seed).permutation(n_trials)

    cut = n_trials // 2
    pseudo_right_idx, pseudo_left_idx = order[:cut], order[cut:]

    avg_a = neutral[pseudo_right_idx].average()
    avg_b = neutral[pseudo_left_idx].average()

    contra_a = get_channel_value(avg_a, LEFT_CH)
    ipsi_a = get_channel_value(avg_a, RIGHT_CH)
    contra_b = get_channel_value(avg_b, RIGHT_CH)
    ipsi_b = get_channel_value(avg_b, LEFT_CH)

    contra = 0.5 * (contra_a + contra_b)
    ipsi = 0.5 * (ipsi_a + ipsi_b)

    to_uv = 1e6
    return (
        avg_a.times * 1000.0,
        contra * to_uv,
        ipsi * to_uv,
        (contra - ipsi) * to_uv,
        seed,
        len(pseudo_right_idx),
        len(pseudo_left_idx),
    )


def pick_n2pc_peak(times_ms, diff_uV, subj, condition, snap_to_wave=True):
    """Let the user click the N2pc peak on a difference wave.

    Args:
        times_ms: Increasing time axis in milliseconds.
        diff_uV: Contra − ipsi difference wave (µV) sampled at ``times_ms``.
        subj: Participant ID, used only in the plot title.
        condition: Condition name, used only in the plot title.
        snap_to_wave: If True (default), the returned amplitude is the
            difference wave linearly interpolated at the clicked time. The
            value then does not depend on how precisely the click hit the
            trace vertically, and it matches how the neutral amplitude is
            measured. If False, the raw y-coordinate of the click is returned.

    Returns:
        ``(peak_time_ms, peak_amp_uV)`` as floats.

    Raises:
        RuntimeError: If the window is closed without a click.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.plot(times_ms, diff_uV, color="black", linewidth=2, label="Contra − Ipsi")
        ax.axvspan(PEAK_SEARCH_MS[0], PEAK_SEARCH_MS[1], color="gray", alpha=0.2)
        ax.axhline(0, linewidth=1)

        ax.set_xlim(PEAK_SEARCH_MS[0] - 50, PEAK_SEARCH_MS[1] + 50)
        ax.set_xlabel("Time (ms)")
        ax.set_ylabel("Amplitude (µV)")
        ax.set_title(f"{subj} — {condition}\nClick N2pc peak on difference wave")
        ax.legend()

        plt.tight_layout()
        plt.show(block=False)
        click = plt.ginput(1, timeout=-1)  # wait indefinitely for one click
    finally:
        # Close the figure even if ginput is interrupted.
        plt.close(fig)

    if not click:
        raise RuntimeError("No peak selected")

    peak_time_ms, click_amp_uV = click[0]
    if not PEAK_SEARCH_MS[0] <= peak_time_ms <= PEAK_SEARCH_MS[1]:
        print(
            f"  Warning: {subj} {condition} peak at {peak_time_ms:.1f} ms is "
            f"outside the search window {PEAK_SEARCH_MS}"
        )

    if snap_to_wave:
        peak_amp_uV = interp_at_time(times_ms, diff_uV, peak_time_ms)
    else:
        peak_amp_uV = click_amp_uV
    return float(peak_time_ms), float(peak_amp_uV)


def interp_at_time(times_ms: np.ndarray, y: np.ndarray, t_ms: float) -> float:
    """Linearly interpolate the sampled signal ``y(times_ms)`` at ``t_ms``.

    ``times_ms`` must be increasing (an ``np.interp`` requirement). Outside the
    sampled range ``np.interp`` returns the first/last value of ``y``; there is
    no extrapolation.
    """
    value = np.interp(x=t_ms, xp=times_ms, fp=y)
    return float(value)


def participant_name_from_filename(path):
    """Derive the participant ID from an epochs file path.

    Drops the directory, then removes the trailing
    ``_manualOutliersRemoved-epo.fif`` suffix, or failing that a bare trailing
    ``.fif``. Only a suffix at the very end is removed, so ``.fif`` appearing
    elsewhere in the file name is kept intact.

    Examples:
        >>> participant_name_from_filename("S01_manualOutliersRemoved-epo.fif")
        'S01'
    """
    base = os.path.basename(path)
    for suffix in ("_manualOutliersRemoved-epo.fif", ".fif"):
        if base.endswith(suffix):
            return base[: -len(suffix)]
    return base


# -----------------------------
# MAIN
# -----------------------------
def _plot_contra_ipsi(times_ms, contra_uV, ipsi_uV, diff_uV, title, marker=None):
    """Show a blocking overview plot of contra, ipsi and contra − ipsi waves.

    Args:
        times_ms: Shared time axis in ms.
        contra_uV: Contralateral waveform (µV).
        ipsi_uV: Ipsilateral waveform (µV).
        diff_uV: Contra − ipsi difference waveform (µV).
        title: Figure title.
        marker: Optional ``(t_ms, amp_uV)``. If given, a vertical line at
            ``t_ms`` and a dot at that point are drawn.
    """
    plt.figure(figsize=(8, 4))
    plt.plot(times_ms, contra_uV, label="Contra")
    plt.plot(times_ms, ipsi_uV, label="Ipsi")
    plt.plot(times_ms, diff_uV, label="Contra − Ipsi", linewidth=2)
    plt.axvspan(*PEAK_SEARCH_MS, alpha=0.2)
    plt.axhline(0, linewidth=1)

    if marker is not None:
        # Visual indicator of an automatically sampled time point.
        t_ms, amp_uV = marker
        plt.axvline(t_ms, linewidth=1)
        plt.plot([t_ms], [amp_uV], marker="o")

    plt.xlabel("Time (ms)")
    plt.ylabel("Amplitude (µV)")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()


def _manual_condition_row(epochs, subj, condition, events):
    """Plot a lateralized condition, collect a manual peak click, return its CSV row."""
    times_ms, contra_uV, ipsi_uV, diff_uV = compute_contra_ipsi_lr(
        epochs, events["R"], events["L"]
    )
    _plot_contra_ipsi(times_ms, contra_uV, ipsi_uV, diff_uV, f"{subj} — {condition}")

    peak_t, peak_a = pick_n2pc_peak(times_ms, diff_uV, subj, condition)
    return {
        "participant": subj,
        "condition": condition,
        "peak_method": "manual",
        "peak_time_ms": peak_t,
        "peak_amp_uV": peak_a,
    }


def _neutral_row(epochs, subj, tst_t, gt_t):
    """Sample the neutral difference wave at the mean of the TST/GT peak times."""
    avg_t = 0.5 * (tst_t + gt_t)

    n_times_ms, n_contra_uV, n_ipsi_uV, n_diff_uV, seed_used, n_a, n_b = (
        compute_contra_ipsi_neutral_random_split(epochs, EVENT_N_LABEL, subj_for_seed=subj)
    )

    # Neutral amplitude at avg_t (linear interpolation between samples).
    n_amp_at_avg_t = interp_at_time(n_times_ms, n_diff_uV, avg_t)

    _plot_contra_ipsi(
        n_times_ms, n_contra_uV, n_ipsi_uV, n_diff_uV,
        f"{subj} — N (auto @ avg(TST,GT)={avg_t:.1f} ms; seed={int(seed_used)}, nA={int(n_a)}, nB={int(n_b)})",
        marker=(avg_t, n_amp_at_avg_t),
    )

    return {
        "participant": subj,
        "condition": "N",
        "peak_method": "auto_avg_TST_GT_time",
        "peak_time_ms": float(avg_t),
        "peak_amp_uV": float(n_amp_at_avg_t),
        "neutral_label_used": EVENT_N_LABEL,
        "neutral_split_seed": int(seed_used),
        "neutral_n_pseudo_right": int(n_a),
        "neutral_n_pseudo_left": int(n_b),
        "tst_manual_time_ms": float(tst_t),
        "gt_manual_time_ms": float(gt_t),
    }


def main():
    """Interactive pipeline: pick input/output folders, measure N2pc, write a CSV.

    For every ``*_manualOutliersRemoved-epo.fif`` file in the input folder:
    TST and GT peaks are picked manually on their difference waves, and the
    neutral (N) amplitude is sampled at the mean of the two picked times.
    Results go to ``n2pc_peaks_TST_GT_manual_N_auto.csv`` in the output folder.

    A failure for one participant (missing event or channel, too few neutral
    trials, no click, unreadable file) is reported and skipped rather than
    aborting the run. Rows already collected for that participant are kept.
    """
    in_dir = pick_folder("Select folder with *_manualOutliersRemoved-epo.fif files")
    if not in_dir:
        return

    out_dir = pick_folder("Select output folder")
    if not out_dir:
        return

    fif_files = sorted(glob.glob(os.path.join(in_dir, "*_manualOutliersRemoved-epo.fif")))
    if not fif_files:
        messagebox.showerror("Error", "No matching files found")
        return

    rows = []
    failures = []

    for fpath in fif_files:
        subj = participant_name_from_filename(fpath)
        print(f"\nProcessing {subj}")

        subj_rows = []
        try:
            epochs = mne.read_epochs(fpath, preload=True, verbose="ERROR")

            tst_row = _manual_condition_row(epochs, subj, "TST", EVENTS_TST)
            subj_rows.append(tst_row)

            gt_row = _manual_condition_row(epochs, subj, "GT", EVENTS_GT)
            subj_rows.append(gt_row)

            subj_rows.append(
                _neutral_row(epochs, subj, tst_row["peak_time_ms"], gt_row["peak_time_ms"])
            )
        except (ValueError, RuntimeError, OSError) as err:
            print(f"  Skipping remaining steps for {subj}: {err}")
            failures.append(f"{subj}: {err}")

        # Keep any manual picks made before a failure so that work is not lost.
        rows.extend(subj_rows)

    os.makedirs(out_dir, exist_ok=True)
    out_csv = os.path.join(out_dir, "n2pc_peaks_TST_GT_manual_N_auto.csv")
    pd.DataFrame(rows).to_csv(out_csv, index=False)

    if failures:
        messagebox.showwarning(
            "Done with problems",
            f"Saved peaks to:\n{out_csv}\n\nIncomplete participants:\n" + "\n".join(failures),
        )
    else:
        messagebox.showinfo("Done", f"Saved peaks to:\n{out_csv}")


if __name__ == "__main__":
    # Run the interactive GUI pipeline only when executed as a script,
    # not when this module is imported.
    main()
